#!/usr/bin/env python3
'''
Author: Paschalis Agapitos
Project: Mestizajes

Provides analysis utilities for evaluating temporal clustering patterns in
century-level embedding spaces. The module reduces embedding representations
with PCA, measures how strongly historical periods form clusters using
supervised metrics and permutation testing, and visualizes hierarchical
relationships between centuries through dendrograms.
'''

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import silhouette_score, calinski_harabasz_score
from scipy.cluster.hierarchy import dendrogram, linkage
from typing import Dict, Callable, Optional, List, Tuple
from pathlib import Path

def _default_get_period(century: int) -> str:
    """Default mapping from century to historical period."""
    if century in {-5, -4, -3, -2, -1, 1, 2, 3, 4}:
        return "Classical Antiquity"
    elif 5 <= century <= 14:
        return "Middle Ages"
    elif century in {15, 16, 17}:
        return "Renaissance"
    elif century == 18:
        return "Enlightenment"
    elif century in {19, 20}:
        return "Modern Era"
    return "Other"


# Global plotting defaults, applied as a side effect of importing this module
# (they change matplotlib's process-wide rcParams).
# Order matters: the "ggplot" style defines its own color cycle
# (axes.prop_cycle), and set_palette afterwards replaces that cycle with
# seaborn's colorblind-friendly palette while keeping the rest of ggplot.
plt.style.use("ggplot")
sns.set_palette("colorblind")


def _evaluate_period_clustering(
    X_pca: np.ndarray,
    period_labels: List[str],
    n_permutations: int,
    random_state: int,
) -> Dict[str, float]:
    """Score how well period labels cluster in ``X_pca`` and test significance.

    Computes the Silhouette and Calinski-Harabasz scores for ``period_labels``
    and a permutation-test p-value for the Silhouette score. Prints a
    human-readable report as it goes.

    Returns an empty dict (after printing why) when the scores are undefined
    for these labels, i.e. when there are fewer than 2 distinct periods or
    every sample has its own period.
    """
    n_samples = len(period_labels)
    n_periods = len(set(period_labels))
    metrics: Dict[str, float] = {}

    # silhouette_score / calinski_harabasz_score accept only
    # 2 <= n_labels <= n_samples - 1 and raise ValueError otherwise.
    if n_periods < 2:
        print("Skipping period evaluation (only 1 period found).")
        return metrics
    if n_periods >= n_samples:
        print(
            f"Skipping period evaluation (each of the {n_samples} centuries "
            "has its own period)."
        )
        return metrics

    # Observed metrics
    sil_score_periods = silhouette_score(X_pca, period_labels)
    ch_score_periods = calinski_harabasz_score(X_pca, period_labels)

    metrics["silhouette_score"] = sil_score_periods
    metrics["calinski_harabasz_score"] = ch_score_periods

    print(f"Number of Periods: {n_periods}")
    print(f"Observed Silhouette Score: {sil_score_periods:.4f}")
    print(f"Observed Calinski-Harabasz Score: {ch_score_periods:.4f}")

    # Permutation test: shuffling the labels keeps every period's size fixed
    # and only breaks the link between a century and its period.
    print(f"Running Permutation Test ({n_permutations} permutations)...")
    rng = np.random.default_rng(random_state)
    null_scores = np.array([
        silhouette_score(X_pca, rng.permutation(period_labels))
        for _ in range(n_permutations)
    ])
    mean_null_score = null_scores.mean()
    # p-value: share of shuffled scores >= observed. The +1 in numerator and
    # denominator counts the observed labelling itself, so p is never 0.
    p_value = (np.sum(null_scores >= sil_score_periods) + 1) / (n_permutations + 1)

    metrics["p_value"] = p_value
    metrics["mean_null_silhouette"] = mean_null_score

    print(f"Mean Null Silhouette Score: {mean_null_score:.4f}")
    print(f"P-value (Periods vs Random): {p_value:.4f}")
    if p_value < 0.05:
        print(">> Result: The temporal periods form statistically significant clusters in 3D space.")
    else:
        print(">> Result: The temporal periods do NOT form statistically significant clusters.")

    return metrics


def _plot_century_dendrogram(
    X_pca: np.ndarray,
    centuries: List[int],
    save_dir: Optional[Path],
    file_tag: str,
    show_plot: bool,
) -> None:
    """Draw a Ward-linkage dendrogram of the centuries, then save and/or show it."""
    # NOTE: set_context changes the seaborn/matplotlib context globally.
    sns.set_context("paper", font_scale=1.8)
    fig, ax = plt.subplots(figsize=(12, 8))

    Z = linkage(X_pca, method='ward')
    dendrogram(Z, labels=centuries, ax=ax)
    ax.set_title(f"Hierarchical Clustering Dendrogram - {file_tag.replace('_', ' ').title()}")
    ax.set_xlabel("Century")
    ax.set_ylabel("Distance")
    ax.tick_params(axis='both', which='major', labelsize=14)

    # Act on this figure explicitly rather than on pyplot's "current" figure.
    fig.tight_layout()

    if save_dir:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)
        out_path = save_dir / f"dendrogram_{file_tag}.png"
        fig.savefig(out_path, dpi=500, bbox_inches='tight')
        print(f"\nPlot saved to {out_path}")

    if show_plot:
        plt.show()
    else:
        plt.close(fig)


def assess_clustering_structure_3d(
    century_embeddings: Dict[int, np.ndarray],
    get_period_fn: Optional[Callable[[int], str]] = None,
    n_pca_components: int = 3,
    n_permutations: int = 1000,
    random_state: int = 42,
    save_dir: Optional[Path] = None,
    file_tag: str = "analysis",
    show_plot: bool = True
) -> Dict[str, float]:
    """
    Project century embeddings with PCA and evaluate their clustering structure.

    Performs two types of analysis:
    1. Supervised: evaluates how well historical periods form clusters in the
       PCA space using the Silhouette and Calinski-Harabasz scores, and tests
       the Silhouette score's significance with a permutation test against
       randomly shuffled period labels.
    2. Unsupervised (hierarchical): plots a Ward-linkage dendrogram of the
       centuries in the PCA space.

    Parameters
    ----------
    century_embeddings : Dict[int, np.ndarray]
        Mapping from century to its embedding vector. All vectors are expected
        to have the same length. At least two centuries are required.
    get_period_fn : Callable[[int], str], optional
        Maps a century to a period name. Uses ``_default_get_period`` if None.
    n_pca_components : int
        Number of PCA components (default 3). PCA raises if this exceeds the
        number of centuries or the embedding dimension.
    n_permutations : int
        Number of label shuffles for the permutation test.
    random_state : int
        Seed for PCA and for the permutation RNG.
    save_dir : Path, optional
        If given, the dendrogram is saved as ``dendrogram_{file_tag}.png`` in
        this directory (created if missing).
    file_tag : str
        Used in the plot title and the output file name.
    show_plot : bool
        If True, display the figure with ``plt.show()``; otherwise close it.

    Returns
    -------
    Dict[str, float]
        ``silhouette_score``, ``calinski_harabasz_score``, ``p_value`` and
        ``mean_null_silhouette`` when the period evaluation runs; an empty
        dict when it is skipped (fewer than 2 distinct periods, or every
        century in its own period).

    Raises
    ------
    ValueError
        If fewer than two centuries are supplied.
    """
    if get_period_fn is None:
        get_period_fn = _default_get_period

    # 1. Prepare Data (sorted so rows, labels and dendrogram leaves line up)
    centuries = sorted(century_embeddings.keys())
    if len(centuries) < 2:
        # Ward linkage and the clustering scores need at least 2 observations.
        raise ValueError(
            f"Need at least 2 centuries to assess clustering structure, got {len(centuries)}."
        )
    X_raw = np.array([century_embeddings[c] for c in centuries])

    # 2. PCA Projection
    # Standardize first (consistent with plot_centuries_and_persons)
    scaler = StandardScaler(with_mean=True, with_std=True)
    X_scaled = scaler.fit_transform(X_raw)

    pca = PCA(n_components=n_pca_components, random_state=random_state)
    X_pca = pca.fit_transform(X_scaled)

    print(f"PCA ({n_pca_components}D) Explained Variance Ratio: {pca.explained_variance_ratio_}")
    print(f"Total Explained Variance: {np.sum(pca.explained_variance_ratio_):.4f}")

    # 3. Supervised Evaluation (Periods) with Permutation Test
    period_labels = [get_period_fn(c) for c in centuries]
    print("\n--- Period Clustering Evaluation (Supervised) ---")
    metrics = _evaluate_period_clustering(X_pca, period_labels, n_permutations, random_state)

    # 4. Plotting (Dendrogram Only)
    _plot_century_dendrogram(X_pca, centuries, save_dir, file_tag, show_plot)

    return metrics
